{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "inclusive-adjustment",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from sagemaker.pytorch import PyTorch\n",
    "from utils import (wait_till_all_done, CLUSTER_Augmented_DATASETS_CTXT_20, CLUSTER_Augmented_DATASETS_CTXT_10, CLUSTER_Augmented_DATASETS_CTXT_CHAR_10,\n",
    "                   CLUSTER_Augmented_DATASETS_CTXT_CHAR_20, CLUSTER_Augmented_DATASETS_WDEL_20, CLUSTER_Augmented_DATASETS_WDEL_10)\n",
    "\n",
    "role = 'arn:aws:iam::157264205850:role/dejiao-sagemaker-run'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "guilty-fifty",
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_models = [\"distilbert\"]\n",
    "lr_params = [(5e-06, 100), (1e-05, 100)]\n",
    "contrast_types = [\"Orig\"]\n",
    "temps = [0.5]\n",
    "objectives = [\"contrastive\", \"SCCL\"]\n",
    "datasets = [\"agnews\", \"searchsnippets\", \"stackoverflow\", \"biomedical\", \"tweet\", \"googleT\", \"googleS\", \"googleTS\"]\n",
    "\n",
    "use_pretrain=\"SBERT\"\n",
    "augtype=\"explicit\"\n",
    "batch_size = 400\n",
    "maxlen = 32\n",
    "maxiter = 3000\n",
    "eta = 10\n",
    "alpha = 1.0\n",
    "base_job_name = \"SCCLv2-distil-exp-strategy-hpo-long\"\n",
    "s3_dataroot = \"s3://dejiao-experiment-east1/datasets/psc_shorttext/\"\n",
    "s3_resdir = \"s3://dejiao-experiment-east1/train/SCCL-SBERT-EXP-ALL-LONG/\"\n",
    "\n",
    "# augmentation_stratgies = [\n",
    "#     CLUSTER_Augmented_DATASETS_CTXT_20, \n",
    "#     CLUSTER_Augmented_DATASETS_CTXT_CHAR_20,\n",
    "#     CLUSTER_Augmented_DATASETS_WDEL_20, \n",
    "#     CLUSTER_Augmented_DATASETS_WDEL_10, \n",
    "#     CLUSTER_Augmented_DATASETS_CTXT_10, \n",
    "# ]\n",
    "\n",
    "augmentation_stratgies = [\n",
    "    CLUSTER_Augmented_DATASETS_CTXT_CHAR_10,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afraid-ranch",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "agnews_trans_subst_20 \t 4 \t text \t alpha:1.0 \n",
      "submit: 1\n",
      "distilbert \t lr: 5e-06\n",
      "agnews_trans_subst_20 \t 4 \t text \t alpha:1.0 \n",
      "submit: 2\n",
      "distilbert \t lr: 5e-06\n",
      "agnews_trans_subst_20 \t 4 \t text \t alpha:1.0 \n",
      "submit: 3\n",
      "distilbert \t lr: 1e-05\n",
      "agnews_trans_subst_20 \t 4 \t text \t alpha:1.0 \n",
      "submit: 4\n",
      "distilbert \t lr: 1e-05\n",
      "searchsnippets_trans_subst_20 \t 8 \t text \t alpha:1.0 \n",
      "submit: 5\n",
      "distilbert \t lr: 5e-06\n",
      "searchsnippets_trans_subst_20 \t 8 \t text \t alpha:1.0 \n",
      "submit: 6\n",
      "distilbert \t lr: 5e-06\n",
      "searchsnippets_trans_subst_20 \t 8 \t text \t alpha:1.0 \n",
      "submit: 7\n",
      "distilbert \t lr: 1e-05\n",
      "searchsnippets_trans_subst_20 \t 8 \t text \t alpha:1.0 \n",
      "submit: 8\n",
      "distilbert \t lr: 1e-05\n",
      "stackoverflow_trans_subst_20 \t 20 \t text \t alpha:10.0 \n",
      "submit: 9\n",
      "distilbert \t lr: 5e-06\n",
      "stackoverflow_trans_subst_20 \t 20 \t text \t alpha:10.0 \n",
      "submit: 10\n",
      "distilbert \t lr: 5e-06\n",
      "stackoverflow_trans_subst_20 \t 20 \t text \t alpha:10.0 \n",
      "submit: 11\n",
      "distilbert \t lr: 1e-05\n",
      "stackoverflow_trans_subst_20 \t 20 \t text \t alpha:10.0 \n",
      "submit: 12\n",
      "distilbert \t lr: 1e-05\n",
      "biomedical_trans_subst_20 \t 20 \t text \t alpha:10.0 \n",
      "submit: 13\n",
      "distilbert \t lr: 5e-06\n",
      "biomedical_trans_subst_20 \t 20 \t text \t alpha:10.0 \n",
      "submit: 14\n",
      "distilbert \t lr: 5e-06\n",
      "biomedical_trans_subst_20 \t 20 \t text \t alpha:10.0 \n",
      "submit: 15\n",
      "distilbert \t lr: 1e-05\n",
      "biomedical_trans_subst_20 \t 20 \t text \t alpha:10.0 \n",
      "submit: 16\n",
      "distilbert \t lr: 1e-05\n",
      "tweet89_trans_subst_20 \t 89 \t text \t alpha:1.0 \n",
      "submit: 17\n",
      "distilbert \t lr: 5e-06\n",
      "tweet89_trans_subst_20 \t 89 \t text \t alpha:1.0 \n",
      "submit: 18\n",
      "distilbert \t lr: 5e-06\n",
      "tweet89_trans_subst_20 \t 89 \t text \t alpha:1.0 \n",
      "submit: 19\n",
      "distilbert \t lr: 1e-05\n",
      "tweet89_trans_subst_20 \t 89 \t text \t alpha:1.0 \n",
      "submit: 20\n",
      "distilbert \t lr: 1e-05\n",
      "googlenews_T_trans_subst_20 \t 152 \t text \t alpha:1.0 \n",
      "submit: 21\n",
      "distilbert \t lr: 5e-06\n",
      "googlenews_T_trans_subst_20 \t 152 \t text \t alpha:1.0 \n",
      "submit: 22\n",
      "distilbert \t lr: 5e-06\n",
      "googlenews_T_trans_subst_20 \t 152 \t text \t alpha:1.0 \n",
      "submit: 23\n",
      "distilbert \t lr: 1e-05\n",
      "googlenews_T_trans_subst_20 \t 152 \t text \t alpha:1.0 \n",
      "submit: 24\n",
      "distilbert \t lr: 1e-05\n",
      "googlenews_S_trans_subst_20 \t 152 \t text \t alpha:1.0 \n",
      "submit: 25\n",
      "distilbert \t lr: 5e-06\n",
      "googlenews_S_trans_subst_20 \t 152 \t text \t alpha:1.0 \n",
      "submit: 26\n",
      "distilbert \t lr: 5e-06\n",
      "googlenews_S_trans_subst_20 \t 152 \t text \t alpha:1.0 \n",
      "submit: 27\n",
      "distilbert \t lr: 1e-05\n",
      "googlenews_S_trans_subst_20 \t 152 \t text \t alpha:1.0 \n",
      "submit: 28\n",
      "distilbert \t lr: 1e-05\n",
      "googlenews_TS_trans_subst_20 \t 152 \t text \t alpha:1.0 \n",
      "submit: 29\n",
      "distilbert \t lr: 5e-06\n",
      "googlenews_TS_trans_subst_20 \t 152 \t text \t alpha:1.0 \n",
      "submit: 30\n",
      "distilbert \t lr: 5e-06\n",
      "googlenews_TS_trans_subst_20 \t 152 \t text \t alpha:1.0 \n",
      "submit: 31\n",
      "distilbert \t lr: 1e-05\n",
      "googlenews_TS_trans_subst_20 \t 152 \t text \t alpha:1.0 \n",
      "submit: 32\n",
      "distilbert \t lr: 1e-05\n"
     ]
    }
   ],
   "source": [
    "idx = 1  \n",
    "\n",
    "for CLUSTER_Augmented_DATASETS in augmentation_stratgies:\n",
    "    wait_till_all_done(base_job_name) \n",
    "    for datakey in datasets:\n",
    "        \n",
    "        for lr, lr_scale in lr_params:\n",
    "            for temperature in temps:\n",
    "                for objective in objectives:\n",
    "                    for ctype in contrast_types:\n",
    "                        for bert in bert_models:\n",
    "                            \n",
    "                            dataname, num_classes, text, label = CLUSTER_Augmented_DATASETS[datakey]\n",
    "                            \n",
    "                            if datakey in [\"stackoverflow\", \"biomedical\"]:\n",
    "                                alpha = 10.0\n",
    "                            else:\n",
    "                                alpha = 1.0\n",
    "                                \n",
    "                            print(f\"{dataname} \\t {num_classes} \\t {text} \\t alpha:{alpha} \")\n",
    "\n",
    "                            hyperparameters = {\n",
    "                                'train_instance': \"sagemaker\",\n",
    "                                'use_pretrain': use_pretrain,\n",
    "                                'datapath': s3_dataroot,\n",
    "                                'dataname': dataname, \n",
    "                                'text': text,\n",
    "                                'label': label,\n",
    "                                'num_classes': num_classes,\n",
    "                                'bert': bert,\n",
    "                                'objective': objective,\n",
    "                                'alpha': alpha,\n",
    "                                'eta': eta, \n",
    "                                'augtype': augtype,\n",
    "                                'contrast_type': ctype,\n",
    "                                'lr': lr,\n",
    "                                'lr_scale': lr_scale,\n",
    "                                'lr_scale_contrast': '100',\n",
    "                                'batch_size': batch_size,\n",
    "                                'max_length': maxlen,\n",
    "                                'temperature': temperature,\n",
    "                                'max_iter': maxiter,\n",
    "                                'print_freq': '100',\n",
    "                                'seed': '0',\n",
    "                                'gpuid': '0',\n",
    "                                'resdir': '/tmp/resnli/PaperTempRes/',\n",
    "                                's3_resdir': s3_resdir,\n",
    "                            }\n",
    "\n",
    "                            try:\n",
    "                                estimator = PyTorch(entry_point='main.py',\n",
    "                                                    source_dir='/home/ec2-user/efs/dejiao-explore/code/SCCL/',\n",
    "                                                    role=role,\n",
    "                                                    instance_count=1,\n",
    "                                                    instance_type='ml.p3.2xlarge',\n",
    "                                                    image_uri='157264205850.dkr.ecr.us-east-1.amazonaws.com/vncl-transformers-p17',\n",
    "                                                    base_job_name = base_job_name,\n",
    "                                                    hyperparameters=hyperparameters,\n",
    "                                                    output_path='s3://dejiao-sagemaker-east1/SCCL/',\n",
    "                                                    framework_version='1.8.1',\n",
    "                                                    py_version = 'py3',\n",
    "                                                    debugger_hook_config=False,\n",
    "                                                    max_run=3 * 24 * 60 * 60,\n",
    "                                                    volume_size = 500,\n",
    "                                                    )\n",
    "\n",
    "                                estimator.fit(wait=False)\n",
    "                                print(\"submit: {}\".format(idx))\n",
    "                            except:\n",
    "                                print(\"submit: {} failed\".format(idx))\n",
    "\n",
    "                            time.sleep(2)\n",
    "                            idx += 1\n",
    "\n",
    "                            print(bert, \"\\t lr:\", lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "accepting-insight",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Environment (conda_pytorch_latest_p37)",
   "language": "python",
   "name": "conda_pytorch_latest_p37"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
